{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image captioning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "画像のキャプション付けは、特定の画像のキャプションを予測するタスクです。一般的な現実世界のアプリケーションには次のものがあります。\n",
    "視覚障害者がさまざまな状況を乗り越えられるよう支援します。したがって、画像のキャプション\n",
    "画像を説明することで人々のコンテンツへのアクセシビリティを向上させるのに役立ちます。\n",
    "\n",
    "このガイドでは、次の方法を説明します。\n",
    "\n",
    "* 画像キャプション モデルを微調整します。\n",
    "* 微調整されたモデルを推論に使用します。\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate -q\n",
    "pip install jiwer -q\n",
    "```\n",
    "\n",
    "モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Pokémon BLIP captions dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "🤗 データセット ライブラリを使用して、{image-caption} ペアで構成されるデータセットを読み込みます。独自の画像キャプション データセットを作成するには\n",
    "PyTorch では、[このノートブック](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb) を参照できます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset(\"lambdalabs/pokemon-blip-captions\")\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "DatasetDict({\n",
    "    train: Dataset({\n",
    "        features: ['image', 'text'],\n",
    "        num_rows: 833\n",
    "    })\n",
    "})\n",
    "```\n",
    "\n",
    "データセットには `image`と`text`の 2 つの機能があります。\n",
    "\n",
    "<Tip>\n",
    "\n",
    "多くの画像キャプション データセットには、画像ごとに複数のキャプションが含まれています。このような場合、一般的な戦略は、トレーニング中に利用可能なキャプションの中からランダムにキャプションをサンプリングすることです。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "`train_test_split` メソッドを使用して、データセットのトレイン スプリットをトレイン セットとテスト セットに分割します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = ds[\"train\"].train_test_split(test_size=0.1)\n",
    "train_ds = ds[\"train\"]\n",
    "test_ds = ds[\"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニング セットからのいくつかのサンプルを視覚化してみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from textwrap import wrap\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def plot_images(images, captions):\n",
    "    plt.figure(figsize=(20, 20))\n",
    "    for i in range(len(images)):\n",
    "        ax = plt.subplot(1, len(images), i + 1)\n",
    "        caption = captions[i]\n",
    "        caption = \"\\n\".join(wrap(caption, 12))\n",
    "        plt.title(caption)\n",
    "        plt.imshow(images[i])\n",
    "        plt.axis(\"off\")\n",
    "\n",
    "\n",
    "sample_images_to_visualize = [np.array(train_ds[i][\"image\"]) for i in range(5)]\n",
    "sample_captions = [train_ds[i][\"text\"] for i in range(5)]\n",
    "plot_images(sample_images_to_visualize, sample_captions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png\" alt=\"Sample training images\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセットには 2 つのモダリティ (画像とテキスト) があるため、前処理パイプラインは画像とキャプションを前処理します。\n",
    "\n",
    "これを行うには、微調整しようとしているモデルに関連付けられたプロセッサ クラスをロードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "checkpoint = \"microsoft/git-base\"\n",
    "processor = AutoProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "プロセッサは内部で画像を前処理し (サイズ変更やピクセル スケーリングを含む)、キャプションをトークン化します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transforms(example_batch):\n",
    "    images = [x for x in example_batch[\"image\"]]\n",
    "    captions = [x for x in example_batch[\"text\"]]\n",
    "    inputs = processor(images=images, text=captions, padding=\"max_length\")\n",
    "    inputs.update({\"labels\": inputs[\"input_ids\"]})\n",
    "    return inputs\n",
    "\n",
    "\n",
    "train_ds.set_transform(transforms)\n",
    "test_ds.set_transform(transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "データセットの準備ができたら、微調整用にモデルをセットアップできます。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a base model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[\"microsoft/git-base\"](https://huggingface.co/microsoft/git-base) を [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) オブジェクト。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "画像キャプション モデルは通常、[Rouge Score](https://huggingface.co/spaces/evaluate-metric/rouge) または [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/) で評価されます。そうだった）。このガイドでは、Word Error Rate (WER) を使用します。\n",
    "\n",
    "これを行うには 🤗 Evaluate ライブラリを使用します。 WER の潜在的な制限やその他の問題点については、[このガイド](https://huggingface.co/spaces/evaluate-metric/wer) を参照してください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluate import load\n",
    "import torch\n",
    "\n",
    "wer = load(\"wer\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predicted = logits.argmax(-1)\n",
    "    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)\n",
    "    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)\n",
    "    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)\n",
    "    return {\"wer_score\": wer_score}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "これで、モデルの微調整を開始する準備が整いました。これには 🤗 [Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) を使用します。\n",
    "\n",
    "まず、[TrainingArguments](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.TrainingArguments) を使用してトレーニング引数を定義します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "model_name = checkpoint.split(\"/\")[1]\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-pokemon\",\n",
    "    learning_rate=5e-5,\n",
    "    num_train_epochs=50,\n",
    "    fp16=True,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    gradient_accumulation_steps=2,\n",
    "    save_total_limit=3,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=50,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=50,\n",
    "    logging_steps=50,\n",
    "    remove_unused_columns=False,\n",
    "    push_to_hub=True,\n",
    "    label_names=[\"labels\"],\n",
    "    load_best_model_at_end=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Trainer 次に、次に、データセットとモデルと一緒に 🤗 に渡します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=test_ds,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニングを開始するには、[Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) オブジェクトの [train()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.train) を呼び出すだけです。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニングが進むにつれて、トレーニングの損失がスムーズに減少することがわかります。\n",
    "\n",
    "トレーニングが完了したら、 [push_to_hub()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.push_to_hub) メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`test_ds` からサンプル画像を取得してモデルをテストします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "\n",
    "url = \"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png\" alt=\"Test image\"/>\n",
    "</div>\n",
    "\n",
    "モデル用の画像を準備します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "inputs = processor(images=image, return_tensors=\"pt\").to(device)\n",
    "pixel_values = inputs.pixel_values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`generate` を呼び出して予測をデコードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_ids = model.generate(pixel_values=pixel_values, max_length=50)\n",
    "generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "print(generated_caption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "a drawing of a pink and blue pokemon\n",
    "```\n",
    "\n",
    "微調整されたモデルにより、非常に優れたキャプションが生成されたようです。"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
